
2024-07-16 14:59 | Source: compiled from the web

PolyLoss in Pytorch

A PyTorch implementation of PolyLoss, as described in [Leng et al. 2022] PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions.
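For reference, the Poly-1 formulations from the paper are written out below. Here $P_t$ is the model's predicted probability for the target class, and $\epsilon_1$ is the coefficient that appears to correspond to the `epsilon` parameter listed under Parameters.

Expanding $-\log(P_t)$ as a Taylor series in $(1-P_t)$ shows that Poly-1 only perturbs the coefficient of the leading polynomial term. Focal loss's optional $\alpha_t$ balancing factor is omitted here for brevity.

```latex
\begin{aligned}
L_{\text{CE}} &= -\log(P_t) = \sum_{j=1}^{\infty} \frac{1}{j}(1-P_t)^{j} \\
L_{\text{Poly-1}} &= -\log(P_t) + \epsilon_1 (1-P_t) \\
L_{\text{FL}} &= -(1-P_t)^{\gamma}\log(P_t) = \sum_{j=1}^{\infty} \frac{1}{j}(1-P_t)^{j+\gamma} \\
L_{\text{Poly-1-FL}} &= -(1-P_t)^{\gamma}\log(P_t) + \epsilon_1 (1-P_t)^{\gamma+1}
\end{aligned}
```

With $\epsilon_1 = 0$, both Poly-1 losses reduce to plain cross-entropy and plain focal loss, respectively.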

Both Poly-1 Cross-Entropy (Poly1CrossEntropyLoss) and Poly-1 Focal (Poly1FocalLoss) losses are provided.
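To make the two losses concrete, here is a minimal functional sketch of each. `poly1_cross_entropy`, `poly1_focal` and `_reduce` are hypothetical names, not this repository's API, and this is not the repository's code.

- The cross-entropy variant adds `epsilon * (1 - P_t)` to softmax cross-entropy.
- The focal variant is sigmoid-based: it computes per-class binary cross-entropy with `pos_weight`, in line with the note that Poly1FocalLoss passes arguments to binary_cross_entropy.
- The focal variant then applies a RetinaNet-style `alpha_t` factor (an assumption about how `alpha` is used) and adds `epsilon * (1 - p_t) ** (gamma + 1)`.
- Class-id labels are converted to one-hot with the class axis at dim 1, mirroring what `label_is_onehot=False` implies.

```python
import torch
import torch.nn.functional as F


def _reduce(loss: torch.Tensor, reduction: str) -> torch.Tensor:
    # Mirrors the 'none' | 'sum' | 'mean' options listed under Parameters
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss


def poly1_cross_entropy(logits, labels, epsilon=1.0, reduction="none"):
    """Hypothetical Poly-1 CE sketch: -log(P_t) + epsilon * (1 - P_t), softmax over dim 1."""
    ce = F.cross_entropy(logits, labels, reduction="none")  # -log(P_t)
    pt = torch.softmax(logits, dim=1).gather(1, labels.unsqueeze(1)).squeeze(1)  # P_t
    return _reduce(ce + epsilon * (1 - pt), reduction)


def poly1_focal(logits, labels, num_classes, epsilon=1.0, alpha=0.25, gamma=2.0,
                label_is_onehot=False, pos_weight=None, reduction="none"):
    """Hypothetical sigmoid Poly-1 focal sketch: FL + epsilon * (1 - p_t) ** (gamma + 1)."""
    if not label_is_onehot:
        # class ids [B, ...] -> one-hot [B, ..., C] -> class axis moved to dim 1
        onehot = F.one_hot(labels, num_classes)
        labels = onehot.permute(0, onehot.dim() - 1, *range(1, onehot.dim() - 1))
    labels = labels.to(logits.dtype)
    p = torch.sigmoid(logits)
    # element-wise binary cross-entropy; positives are scaled by pos_weight if given
    bce = F.binary_cross_entropy_with_logits(
        logits, labels, pos_weight=pos_weight, reduction="none")
    pt = labels * p + (1 - labels) * (1 - p)  # probability of the true binary label
    focal = bce * (1 - pt) ** gamma
    if alpha >= 0:
        # assumed RetinaNet-style alpha_t balancing
        focal = (alpha * labels + (1 - alpha) * (1 - labels)) * focal
    return _reduce(focal + epsilon * (1 - pt) ** (gamma + 1), reduction)


torch.manual_seed(0)
# classification: [N, C] logits, [N] class ids
logits = torch.randn(10, 5, requires_grad=True)
labels = torch.randint(0, 5, (10,))
poly1_cross_entropy(logits, labels, reduction="mean").backward()

# segmentation-style: [B, C, H, W] logits, [B, H, W] class ids
B, C, H, W = 2, 3, 4, 7
seg_logits = torch.randn(B, C, H, W, requires_grad=True)
seg_labels = torch.randint(0, C, (B, H, W))
pos_weight = torch.tensor([1.0, 1.0, 5.0]).reshape(1, C, 1, 1)
poly1_focal(seg_logits, seg_labels, num_classes=C, pos_weight=pos_weight,
            reduction="mean").backward()
```

Setting `epsilon=0.0` turns each function back into its base loss. This makes it easy to check the Poly-1 term in isolation.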

Examples

```python
import torch

# Assumption: the loss classes are importable from polyloss.py in this repository
from polyloss import Poly1CrossEntropyLoss, Poly1FocalLoss

# Poly1 Cross-Entropy Loss
# classification task
batch_size = 10
num_classes = 5
# requires_grad=True so that out.backward() below has something to differentiate
logits = torch.rand([batch_size, num_classes], requires_grad=True)
labels = torch.randint(0, num_classes, [batch_size])

loss = Poly1CrossEntropyLoss(num_classes=num_classes, reduction='mean')
out = loss(logits, labels)
out.backward()
# optimizer.step()

# Poly1 Focal Loss
## Case 1. labels hold class ids
# batch_size, num_classes, height, width
B, num_classes, H, W = 2, 3, 4, 7
logits = torch.rand([B, num_classes, H, W])
labels = torch.randint(0, num_classes, [B, H, W])

# optional class-wise positive weights, shape must be broadcastable to [B, num_classes, H, W]
# weight positive examples of class id 2 five times more
pos_weight = torch.tensor([1., 1., 5.]).reshape([1, num_classes, 1, 1])

loss = Poly1FocalLoss(num_classes=num_classes, reduction='mean', label_is_onehot=False, pos_weight=pos_weight)
out = loss(logits, labels)
# out.backward()
# optimizer.step()

## Case 2. labels are one-hot or multi-hot (in case of multi-label task) encoded
# batch_size, num_classes, height, width
B, num_classes, H, W = 2, 3, 4, 7
logits = torch.rand([B, num_classes, H, W])
# multi-hot labels of the same shape as logits
labels = torch.randint(0, 2, [B, num_classes, H, W]).float()

# optional class-wise positive weights, shape must be broadcastable to [B, num_classes, H, W]
# weight positive examples of class id 2 five times more
pos_weight = torch.tensor([1., 1., 5.]).reshape([1, num_classes, 1, 1])
# weight tensor shape [1, num_classes, 1, 1] is broadcastable to [B, num_classes, H, W]

loss = Poly1FocalLoss(num_classes=num_classes, reduction='mean', label_is_onehot=True, pos_weight=pos_weight)
out = loss(logits, labels)
# out.backward()
# optimizer.step()
```

Parameters

Poly1CrossEntropyLoss

- num_classes (int): number of classes
- epsilon (float, default 1.0): PolyLoss epsilon
- reduction (str, default 'none'): reduction applied to the output, one of 'none' | 'sum' | 'mean'
- weight (torch.Tensor, default None): manual rescaling weight for each class, passed to the cross-entropy loss

Poly1FocalLoss

- num_classes (int): number of classes
- epsilon (float, default 1.0): PolyLoss epsilon
- alpha (float, default 0.25): focal loss alpha
- gamma (float, default 2.0): focal loss gamma
- reduction (str, default 'none'): reduction applied to the output, one of 'none' | 'sum' | 'mean'
- weight (torch.Tensor, default None): manual rescaling weight given to the loss of each batch element, passed to the underlying binary_cross_entropy loss (*)
- pos_weight (torch.Tensor, default None): weight of positive examples, passed to the underlying binary_cross_entropy loss (*)
- label_is_onehot (bool, default False): set to True if labels are one-hot (or multi-hot) encoded

(*) See the formulas on the documentation page for BCEWithLogitsLoss to understand how the weight (w_n) and pos_weight (p_c) parameters enter the loss function and how they affect the loss. A detailed explanation is coming soon. Further discussion can be found in this thread and this thread.
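The BCEWithLogitsLoss documentation gives the per-element loss as ℓ_{n,c} = −w_n [p_c · y_{n,c} · log σ(x_{n,c}) + (1 − y_{n,c}) · log(1 − σ(x_{n,c}))]. The two weights act differently:

- w_n rescales every term of a sample.
- p_c rescales only the positive (y = 1) term of class c.

The small sanity check below reproduces that formula by hand and compares it with `torch.nn.functional.binary_cross_entropy_with_logits`. It is not part of this repository, and the tensor shapes are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3)                    # logits x_{n,c}: 4 samples, 3 classes
y = torch.randint(0, 2, (4, 3)).float()  # binary (multi-hot) targets y_{n,c}
w = torch.rand(4, 1)                     # per-sample weight w_n, broadcast over classes
p = torch.tensor([1.0, 1.0, 5.0])        # per-class pos_weight p_c (class 2 positives x5)

# log(1 - sigmoid(x)) == logsigmoid(-x), the numerically stable form
manual = -w * (p * y * F.logsigmoid(x) + (1 - y) * F.logsigmoid(-x))
builtin = F.binary_cross_entropy_with_logits(x, y, weight=w, pos_weight=p, reduction="none")

# With all-negative targets, pos_weight has no effect
zeros = torch.zeros_like(x)
neg_only = F.binary_cross_entropy_with_logits(x, zeros, pos_weight=p, reduction="none")
neg_plain = F.binary_cross_entropy_with_logits(x, zeros, reduction="none")
```

This is why a `pos_weight` of 5 for class id 2 in the examples boosts only that class's positive pixels, and leaves its negatives untouched.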

Requirements

- Python 3.6+
- PyTorch 1.1+

